%% 2023/02/03
% Load two-particle quantum walk data with effective on-site
% interaction energy = 0, analyze the evolution of the density profiles 
% and correlations to extract signatures of pairing for non-zero
% statistical phase.

% 2023/02/27:
% - Bootstrap analysis to estimate error bars on the RMS size

close all
clear


%% Directories containing the data

J = 2*pi * 10.6;        % Tunneling amplitude
theta_list = 0:0.25:1;  % Statistical phases, in units of pi
N = numel(theta_list);

common_root = '2023_01_';

% Data folders, one per statistical phase, in the same order as theta_list.
% The recording date is encoded in the folder name (2023_01_<day>_...).
folder_list = strcat(common_root, { ...
    '13_4_2D2_qw_n2_6Er_modfreq3_1010_phase0', ...     % theta = 0       (2023/01/13)
    '21_1_2D2_qw_n2_6Er_modfreq3_1010_phase0p25', ...  % theta = 0.25 pi (2023/01/21)
    '09_1_2D2_qw_n2_6Er_modfreq3_1010_phase0p5', ...   % theta = 0.5 pi  (2023/01/09)
    '23_2_2D2_qw_n2_6Er_modfreq3_1010_phase0p75', ...  % theta = 0.75 pi (2023/01/23)
    '05_1_2D2_qw_n2_6Er_modfreq3_1010_phase1p0'});     % theta = pi      (2023/01/05)

% Region of interest of atomMatrix analysed for each data set, one row per
% theta (same order as theta_list): [col1, col2, row1, row2]
roi = [ ...
    115 128 90 113    % theta = 0
    116 128 90 113    % theta = 0.25 pi
    115 128 90 113    % theta = 0.5 pi
    115 128 90 113    % theta = 0.75 pi
    118 128 90 113];  % theta = pi
col1 = roi(:, 1)';
col2 = roi(:, 2)';
row1 = roi(:, 3)';
row2 = roi(:, 4)';

% Fail early if a folder or ROI is missing or extra for some theta
assert(numel(folder_list) == N && size(roi, 1) == N, ...
    'Expected one data folder and one ROI per entry of theta_list.');


%% Other things that can be useful

% Bootstrap: number of resamples drawn per time step
N_BS = 100;

% Straight-line model y = a*x + b, wrapped in a fittype for use with fit()
AffineFun = @(a, b, x) a*x + b;
Affine_Fittype = fittype(AffineFun);

% First three MATLAB default line colors (blue, orange, yellow)
c1 = [0 0.4470 0.7410];
c2 = [0.8500 0.3250 0.0980];
c3 = [0.9290 0.6940 0.1250];

% Raw-data loading: file-name prefix of the scan files, and the name of the
% scanned parameter as it appears in the batch parameter table
prefix = 'scan';
paramVar = 'evolution_dur';


%% Load the data

% Fitted expansion velocities and their uncertainties, one entry per theta
velocity_close = zeros(1, N);
err_velocity_close = zeros(1, N);
velocity_far = zeros(1, N);
err_velocity_far = zeros(1, N);

for k = 1:N
    fprintf('k = %d\n', k);

    folder = folder_list{k};
    [param_names, param_table] = get_batch_params(folder);
    paramInd = find(strcmp(param_names, paramVar));
    M = row2(k) - row1(k) + 1;   % number of ROI rows (sites) per column

    % Evolution time of every parameter-table row; its distinct values are
    % the time steps of the analysis
    alltvalues = param_table(:, paramInd);
    time = unique(alltvalues);
    N_t = numel(time);

    rms_size_close = zeros(1, N_t);
    err_rms_size_close = zeros(1, N_t);
    rms_size_far = zeros(1, N_t);
    err_rms_size_far = zeros(1, N_t);

    % Doublon detection: rows with doublon_ramp_dur > 0 feed the "_b" arrays,
    % all other rows feed the "_a" arrays
    doublonInd = find(strcmp(param_names, 'doublon_ramp_dur'));
    doublon_ramp_dur_all = param_table(:, doublonInd);
    doublon_ramp_dur_list = unique(doublon_ramp_dur_all);

    % Per-event pair maps, indexed (time step, site, site, event). An event is
    % one ROI column holding exactly two atoms, at sites r1 < r2:
    %   corr_diff2_a/b : r2 - r1 > 1            -> ones at (r1, r2) and (r2, r1)
    %   corr_diff1_a   : adjacent, no ramp      -> ones at (r1, r2) and (r2, r1)
    %   corr_diff0_b   : adjacent, doublon ramp -> a 2 at (r2, r2)
    % N_a / N_b count the events per time step in each class.
    corr_diff2_b = [];
    corr_diff0_b = [];
    corr_diff2_a = [];
    corr_diff1_a = [];

    N_a = zeros(1, N_t);
    N_b = zeros(1, N_t);

    tic
    for dd = 1:numel(doublon_ramp_dur_list)
        doublon_ramp_dur = doublon_ramp_dur_list(dd);

        for j1 = 1:N_t
            good_index = (alltvalues == time(j1)) & (doublon_ramp_dur_all == doublon_ramp_dur);
            % NOTE(review): the last parameter-table row is always excluded,
            % presumably because the final scan is incomplete - confirm
            good_index(end) = 0;
            good_index = find(good_index)';

            for j2 = good_index
                % All atom-matrix files belonging to scan j2
                listing = dir(fullfile(folder, [prefix '*' num2str(j2, '%03.f') 'atomMatrix.mat']));

                for s = 1:numel(listing)
                    f = load(fullfile(folder, listing(s).name));

                    for col = col1(k):col2(k)
                        column = f.atomMatrix(row1(k):row2(k), col);
                        % Keep only columns with exactly two atoms
                        % (assumes atomMatrix entries are 0/1 - TODO confirm)
                        if sum(column) ~= 2
                            continue
                        end
                        row = find(column);   % sorted site indices within the ROI

                        if doublon_ramp_dur > 0
                            N_b(j1) = N_b(j1) + 1;
                            n = N_b(j1);
                            % Zero-fill this event's M x M slab in both arrays
                            % so that the event index exists in each of them
                            corr_diff2_b(j1, 1:M, 1:M, n) = 0;
                            corr_diff0_b(j1, 1:M, 1:M, n) = 0;
                            if diff(row) > 1
                                corr_diff2_b(j1, row(1), row(2), n) = 1;
                                corr_diff2_b(j1, row(2), row(1), n) = 1;
                            else
                                corr_diff0_b(j1, row(2), row(2), n) = 2;
                            end
                        else
                            N_a(j1) = N_a(j1) + 1;
                            n = N_a(j1);
                            corr_diff2_a(j1, 1:M, 1:M, n) = 0;
                            corr_diff1_a(j1, 1:M, 1:M, n) = 0;
                            if diff(row) > 1
                                corr_diff2_a(j1, row(1), row(2), n) = 1;
                                corr_diff2_a(j1, row(2), row(1), n) = 1;
                            else
                                corr_diff1_a(j1, row(1), row(2), n) = 1;
                                corr_diff1_a(j1, row(2), row(1), n) = 1;
                            end
                        end
                    end
                end
            end
        end
    end
    toc

    % Pad every event array to [N_t, M, M, max(1, max event count)] so that
    % time steps or classes without any event can still be indexed below.
    % The far corner is assigned only when some dimension is short, so it
    % always lies outside the current extent and no recorded data is touched.
    corr_arrays = {corr_diff2_a, corr_diff1_a, corr_diff2_b, corr_diff0_b};
    corr_sizes = {[N_t, M, M, max([1, N_a])], [N_t, M, M, max([1, N_a])], ...
                  [N_t, M, M, max([1, N_b])], [N_t, M, M, max([1, N_b])]};
    for q = 1:numel(corr_arrays)
        A = corr_arrays{q};
        sz = corr_sizes{q};
        if any([size(A, 1), size(A, 2), size(A, 3), size(A, 4)] < sz)
            A(sz(1), sz(2), sz(3), sz(4)) = 0;
        end
        corr_arrays{q} = A;
    end
    [corr_diff2_a, corr_diff1_a, corr_diff2_b, corr_diff0_b] = corr_arrays{:};


    %% Now resample the data for bootstrap analysis

    % Pair-distance masks on the M x M (site, site) grid. They do not depend
    % on the time step, so they are built once outside the parfor loop.
    x_list = 1:M;
    [X1, X2] = meshgrid(x_list, x_list);
    D_close = 2;
    index_close = abs(X1 - X2) <= D_close;
    index_far = ~index_close;

    % RMS width of a 1-D profile p over the sites x_list (NaN if p sums to 0)
    rms_of = @(p) sqrt( 1 / sum(p) * sum(p .* x_list.^2) - (1 / sum(p) * sum(p .* x_list))^2 );

    tic
    parfor j1 = 1:N_t
        % This time step's events as M x M x (event slots) stacks
        slab_2a = reshape(corr_diff2_a(j1, :, :, :), M, M, []);
        slab_1a = reshape(corr_diff1_a(j1, :, :, :), M, M, []);
        slab_2b = reshape(corr_diff2_b(j1, :, :, :), M, M, []);
        slab_0b = reshape(corr_diff0_b(j1, :, :, :), M, M, []);
        n_a = N_a(j1);
        n_b = N_b(j1);

        rms_size_close_bs = zeros(1, N_BS);
        rms_size_far_bs = zeros(1, N_BS);

        for k2 = 1:N_BS
            % Resample the events of each class with replacement
            % (randi(max(n, 1), 1, n) yields an empty index when n == 0)
            idx_a = randi(max(n_a, 1), 1, n_a);
            idx_b = randi(max(n_b, 1), 1, n_b);
            aux_2a = slab_2a(:, :, idx_a);
            aux_1a = slab_1a(:, :, idx_a);
            aux_2b = slab_2b(:, :, idx_b);
            aux_0b = slab_0b(:, :, idx_b);

            % Number of resampled events whose adjacent-pair / doublon map is non-empty
            n_a_1 = nnz(any(any(aux_1a > 0, 1), 2));
            n_b_0 = nnz(any(any(aux_0b > 0, 1), 2));

            % Normalized pair map: adjacent-pair and doublon maps are multiplied
            % by 2, and those events are counted a second time in the normalization
            corr_map = ( sum(aux_2a, 3) + 2 * sum(aux_1a, 3) + sum(aux_2b, 3) + 2 * sum(aux_0b, 3) ) ...
                / (n_a + n_a_1 + n_b + n_b_0);

            % Conditional density profiles: sum over the first site index,
            % keeping pairs with |x1 - x2| <= D_close (close) or > D_close (far)
            density_close = sum(corr_map .* index_close, 1);
            density_far = sum(corr_map .* index_far, 1);

            rms_size_close_bs(k2) = rms_of(density_close);
            rms_size_far_bs(k2) = rms_of(density_far);
        end

        % Bootstrap estimate (mean) and uncertainty (standard deviation)
        rms_size_close(j1) = mean(rms_size_close_bs);
        err_rms_size_close(j1) = std(rms_size_close_bs);
        rms_size_far(j1) = mean(rms_size_far_bs);
        err_rms_size_far(j1) = std(rms_size_far_bs);
    end
    toc


    %% Fit expansion velocity

    % Affine fit of the RMS size vs. dimensionless time J*t for t > t_fit_min;
    % the fitted slope is the expansion velocity in sites per unit of J*t
    t_fit_min = 15;                  % ms
    fit_window = (time > t_fit_min);
    t_fit = time(fit_window);        % ms, used to plot the fitted lines
    data_x = J*10^-3 * t_fit;        % 10^-3 converts t from ms (plot axis unit) to s
    start_point = [1, 0.5];

    % Close pairs. fit() minimizes sum(w .* residual.^2), so inverse-variance
    % weights are 1 ./ err.^2. Points with a non-finite size or weight (no
    % events, empty profile, zero bootstrap spread) are left out of the fit.
    data_y = rms_size_close(fit_window)';
    weight_data_y = 1 ./ err_rms_size_close(fit_window)'.^2;
    usable = isfinite(data_y) & isfinite(weight_data_y);
    fitresult = fit(data_x(usable), data_y(usable), Affine_Fittype, ...
        'StartPoint', start_point, 'Weights', weight_data_y(usable));
    rms_size_fit_close = fitresult(data_x);
    velocity_close(k) = fitresult.a;
    conf_int = confint(fitresult);
    % Half-width of the confidence interval on the slope (confint default: 95%)
    err_velocity_close(k) = (conf_int(2,1) - conf_int(1,1))/2;

    % Far pairs (same procedure)
    data_y = rms_size_far(fit_window)';
    weight_data_y = 1 ./ err_rms_size_far(fit_window)'.^2;
    usable = isfinite(data_y) & isfinite(weight_data_y);
    fitresult = fit(data_x(usable), data_y(usable), Affine_Fittype, ...
        'StartPoint', start_point, 'Weights', weight_data_y(usable));
    rms_size_fit_far = fitresult(data_x);
    velocity_far(k) = fitresult.a;
    conf_int = confint(fitresult);
    err_velocity_far(k) = (conf_int(2,1) - conf_int(1,1))/2;


    %% Plot the results with fit

    plot_figure = 1;
    if plot_figure
        figure()
        hold on

        % Close pairs: bootstrap data and fitted line, both against t in ms
        errorbar(time, rms_size_close, err_rms_size_close, 'o', ...
            'Capsize', 0, 'Color', c2, 'LineWidth', 1.5, ...
            'DisplayName', ['d\leq2 (v = ' num2str(velocity_close(k)) ' 1/\tau)'] )
        p2 = plot(t_fit, rms_size_fit_close, '--', 'Color', c2, 'LineWidth', 3, 'HandleVisibility', 'off');
        p2.Color(4) = 0.5;

        % Far pairs
        errorbar(time, rms_size_far, err_rms_size_far, 'o', ...
            'Capsize', 0, 'Color', c3, 'LineWidth', 1.5, ...
            'DisplayName', ['d>2 (v = ' num2str(velocity_far(k)) ' 1/\tau)'])
        p3 = plot(t_fit, rms_size_fit_far, '--', 'Color', c3, 'LineWidth', 3, 'HandleVisibility', 'off');
        p3.Color(4) = 0.5;

        xlim([t_fit_min, 60])
        legend('Location', 'Best')
        xlabel('t (ms)')
        ylabel('RMS size (sites)')
        hold off
    end

end


%% Summary plot: Ratio between expansion velocities of close and distant pairs

plot_figure = 1;
if plot_figure
    % Velocity ratio R = v_close / v_far with first-order error propagation
    % for independent errors:
    %   sigma_R^2 = (sigma_close / v_far)^2 + (v_close * sigma_far / v_far^2)^2
    ratio_velocity = velocity_close ./ velocity_far;
    err_ratio_velocity = sqrt( (err_velocity_close ./ velocity_far).^2 ...
        + (velocity_close .* err_velocity_far ./ velocity_far.^2).^2 );

    figure
    hold on
    errorbar(theta_list, ratio_velocity, err_ratio_velocity, 'o', 'Capsize', 0, 'LineWidth', 1.5, 'MarkerSize', 6, 'DisplayName', 'Total')
    set(gca, 'FontSize', 11);
    xlabel('\theta [\pi]', 'FontSize', 12)
    ylabel('v_{d \leq 2 } / v_{d > 2}', 'FontSize', 12)
    box on
    ylim([0,1.5])
    hold off

    % Expansion velocities of distant and close pairs vs. statistical phase
    figure
    box on
    hold on
    errorbar(theta_list, velocity_far, err_velocity_far, 'o', 'Capsize', 0, ...
        'LineWidth', 1.5, 'MarkerSize', 5, 'DisplayName', 'd > 2')
    errorbar(theta_list, velocity_close, err_velocity_close, 'o', 'Capsize', 0, ...
        'LineWidth', 1.5, 'MarkerSize', 5, 'DisplayName', 'd \leq 2')
    legend('Location', 'Best', 'FontSize', 9)
    set(gca, 'FontSize', 11);
    xlabel('\theta [\pi]', 'FontSize', 12)
    ylabel('v', 'FontSize', 12)
    hold off
end
